"""
Interaction head and its submodules

Fred Zhang <frederic.zhang@anu.edu.au>

The Australian National University
Australian Centre for Robotic Vision
"""

import torch
import torch.nn.functional as F
import numpy as np

from torch import nn, Tensor
from typing import List
from collections import OrderedDict

from add_on.net import *

class InteractionHead(nn.Module):
    """
    Interaction head that constructs and classifies box pairs

    For every image, detected instances are reordered so that humans come
    first, valid human-object index pairs are enumerated, and each pair is
    scored as the element-wise product of a human branch and an object branch
    (both ``Instance_Centric_Attention`` modules attending over that image's
    DETR memory).

    Parameters:
    -----------
    box_pair_predictor: nn.Module
        Module that classifies box pairs (stored; not called in this class)
    num_channels: int
        Number of channels in the global image features (currently unused)
    object_class_to_target_class: List[list]
        The set of valid action classes for each object type
    args: namespace
        Configuration. Attributes read: ``device``, ``hidden_dim``,
        ``repr_dim``, ``num_classes``, ``human_idx``, ``nheads`` and, in
        ``forward``, ``dataset`` ('hicodet' or 'vcoco')
    num_query: int              (500)
        Number of queries (stored only)
    """
    def __init__(self,
        box_pair_predictor: nn.Module,
        num_channels,
        object_class_to_target_class: List[list],
        args, num_query=500,
    ) -> None:
        super().__init__()
        self.device = args.device
        self.args = args
        # Kept for backward compatibility; forward() does not read it
        self.whole_dec = False
        self.num_query = num_query
        self.box_pair_predictor = box_pair_predictor

        hidden_state_size = args.hidden_dim
        self.hidden_state_size = hidden_state_size

        representation_size = args.repr_dim
        self.representation_size = representation_size

        self.num_classes = args.num_classes
        self.human_idx = args.human_idx

        self.object_class_to_target_class = object_class_to_target_class

        self.nheads = args.nheads

        # NOTE(review): input_num=256 presumably matches args.hidden_dim and the
        # channel count of detr_memory — confirm before changing hidden_dim.
        self.Human_branch = Instance_Centric_Attention(input_num=256, hidden_num=512, num_classes=args.num_classes)
        self.Object_branch = Instance_Centric_Attention(input_num=256, hidden_num=512, num_classes=args.num_classes)

    def forward(self, images, resnet_features: OrderedDict, srcs,
                detr_memory, image_shapes: Tensor, sample_wh: Tensor,
                region_props: List[dict], masks=None, poses=None):
        """
        Build human-object pairs for every image and score them.

        Only ``detr_memory`` and ``region_props`` are read. The other
        arguments are accepted for interface compatibility.

        Parameters:
        -----------
        detr_memory: Tensor
            Sliced per image as ``detr_memory[b_idx:b_idx+1]``. Its last two
            dimensions give the spatial size of the returned attention maps.
            (Presumably (B, C, H, W) — TODO confirm.)
        region_props: List[dict]
            One dict per image with keys ``boxes``, ``scores``, ``labels``
            and ``hidden_states``

        Returns:
        --------
        logits: Tensor
            (1, total_pairs, num_classes), pair scores concatenated over images
        prior_collated: List[Tensor]
            Per image, (2, pairs, num_classes) human/object priors
        prior_score: List[Tensor]
            Per image, detection scores of the (human-first) instances
        boxes_h_collated, boxes_o_collated: List[Tensor]
            Per image, human / object instance index of every pair
        object_class_collated: List[Tensor]
            Per image, object label of every pair
        attn_maps_collated: List[Tensor]
            Per image, human-branch attention maps (pairs, hidden_num, H, W)

        Raises:
        -------
        ValueError
            If pairs must be built for an image and ``args.dataset`` is
            neither 'hicodet' nor 'vcoco'.
        """
        device = self.device
        boxes_h_collated = []
        boxes_o_collated = []
        prior_collated = []
        object_class_collated = []
        attn_maps_collated = []
        HOI_tokens_collated = []
        prior_score = []
        # Spatial size of the attention maps (same for every image)
        map_h, map_w = detr_memory.shape[-2], detr_memory.shape[-1]

        # Each sample (image) is processed independently
        for b_idx, props in enumerate(region_props):
            n = len(props['boxes'])
            score = props['scores']
            label = props['labels']
            unary_token = props['hidden_states']

            is_human = (label == self.human_idx)
            n_h = torch.sum(is_human)
            # Permute human instances to the top so that indices < n_h are humans
            if not torch.all(is_human):
                h_idx = torch.nonzero(is_human).squeeze(1)
                o_idx = torch.nonzero(~is_human).squeeze(1)
                perm = torch.cat([h_idx, o_idx])
                score = score[perm]
                label = label[perm]
                unary_token = unary_token[perm]

            if n_h == 0 or n <= 1:
                # No human, or too few instances: emit empty placeholders
                boxes_h_collated.append(torch.zeros(0, device=device, dtype=torch.int64))
                boxes_o_collated.append(torch.zeros(0, device=device, dtype=torch.int64))
                object_class_collated.append(torch.zeros(0, device=device, dtype=torch.int64))
                prior_collated.append(torch.zeros(2, 0, self.num_classes, device=device))
                prior_score.append(torch.zeros(0, device=device))
                HOI_tokens_collated.append(
                    torch.zeros((0, self.num_classes), dtype=torch.float, device=device))
                attn_maps_collated.append(torch.zeros(
                    0, self.Human_branch.hidden_num, map_h, map_w, device=device))
                continue

            # All (i, j) index pairs: x indexes the human, y the object  (N, N)
            x, y = torch.meshgrid(torch.arange(n, device=device),
                                  torch.arange(n, device=device))
            if self.args.dataset == 'hicodet':
                # Human subject, and an instance is never paired with itself
                x_keep, y_keep = torch.nonzero(torch.logical_and(x != y, x < n_h)).unbind(1)
            elif self.args.dataset == 'vcoco':
                # Self-pairs (x == y) are kept; compute_prior_scores allows every
                # action class for them (actions without an object)
                x_keep, y_keep = torch.nonzero(x < n_h).unbind(1)
            else:
                raise ValueError(
                    f"Unsupported dataset {self.args.dataset!r}; "
                    "expected 'hicodet' or 'vcoco'")
            prior = self.compute_prior_scores(x_keep, y_keep, score, label)
            prior_score.append(score)
            boxes_h_collated.append(x_keep)
            boxes_o_collated.append(y_keep)
            object_class_collated.append(label[y_keep])
            prior_collated.append(prior)

            # Pair score: element-wise product of human- and object-branch outputs
            memory = detr_memory[b_idx:b_idx + 1]
            human_feat, human_attn = self.Human_branch(unary_token[x_keep], memory)
            object_feat, _ = self.Object_branch(unary_token[y_keep], memory)
            HOI_tokens_collated.append(human_feat * object_feat)
            attn_maps_collated.append(human_attn.reshape(
                human_attn.shape[0], human_attn.shape[1], map_h, map_w))

        if HOI_tokens_collated:
            logits = torch.cat(HOI_tokens_collated, dim=0).unsqueeze(0)
        else:
            # torch.cat rejects an empty list (no images in the batch)
            logits = torch.zeros(1, 0, self.num_classes, device=device)
        return (logits, prior_collated, prior_score,
                boxes_h_collated, boxes_o_collated,
                object_class_collated, attn_maps_collated)

    def compute_prior_scores(self, x_keep: Tensor, y_keep: Tensor, scores: Tensor, object_class: Tensor) -> Tensor:
        """
        Build per-pair prior scores from the detection scores.

        Parameters:
        -----------
        x_keep, y_keep: Tensor
            Human / object instance index of every pair
        scores: Tensor
            Detection score of every instance
        object_class: Tensor
            Label of every instance

        Returns:
        --------
        Tensor
            (2, pairs, num_classes). Index 0 holds human scores and index 1
            holds object scores. Entries are non-zero only for action classes
            valid for the pair's object class (all classes for self-pairs).
        """
        prior_h = torch.zeros(len(x_keep), self.num_classes, device=self.device)
        prior_o = torch.zeros_like(prior_h)
        # Raise detection scores to a power > 1 during inference
        p = 1.0 if self.training else 2.8
        s_h = scores[x_keep].pow(p)
        s_o = scores[y_keep].pow(p)
        # Map each pair's object class to its valid action classes (one-to-many),
        # which filters out invalid (pair, action) combinations. A self-pair
        # (x == y, kept for V-COCO) admits every action class.
        # Transfer indices to Python once instead of one .item() per pair.
        obj_classes = object_class[y_keep].tolist()
        self_pairs = (x_keep == y_keep).tolist()
        target_cls_idx = [
            range(self.num_classes) if is_self else self.object_class_to_target_class[obj]
            for obj, is_self in zip(obj_classes, self_pairs)
        ]
        # Duplicate box pair indices for each target class
        pair_idx = [i for i, tar in enumerate(target_cls_idx) for _ in tar]
        # Flatten mapped target indices
        flat_target_idx = [t for tar in target_cls_idx for t in tar]

        prior_h[pair_idx, flat_target_idx] = s_h[pair_idx]  # (pairs, num_classes)
        prior_o[pair_idx, flat_target_idx] = s_o[pair_idx]  # (pairs, num_classes)

        return torch.stack([prior_h, prior_o])  # (2, pairs, num_classes)
    




class Instance_Centric_Attention(nn.Module):
    """
    Instance-centric attention branch (used for both the human and object streams)

    Each instance token is projected to a per-channel query. The query is
    multiplied with a key map of one image's memory and softmax-normalised
    over the spatial positions, then used to weight a value map. The weighted
    values are average-pooled, concatenated with a direct projection of the
    token, and classified.

    Parameters:
    -----------
    input_num: int
        Feature size of the instance tokens and channel count of the memory
    hidden_num: int
        Number of attention channels
    num_classes: int
        Number of output classes

    NOTE(review): ``inst_fc`` contains BatchNorm1d, so in training mode a call
    with a single instance token is expected to raise inside PyTorch.
    """
    def __init__(self, input_num, hidden_num, num_classes
    ) -> None:
        super().__init__()
        self.hidden_num = hidden_num
        # Direct projection of the instance token (no attention)
        self.Qfc = nn.Sequential(
            nn.Linear(input_num, 2048),
            nn.ReLU(inplace=True),
        )
        # Per-channel attention query
        self.Qconv = nn.Sequential(
            nn.Linear(input_num, hidden_num),
            nn.ReLU(inplace=True),
        )
        self.Kconv = nn.Sequential(
            nn.Conv2d(input_num, hidden_num, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.Vconv = nn.Sequential(
            nn.Conv2d(input_num, hidden_num, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.fc = nn.Sequential(
            nn.Linear(hidden_num, 1024),
            nn.ReLU(inplace=True),
        )
        # Classifier over [Qfc output (2048) ; pooled attention feature (1024)]
        self.inst_fc = nn.Sequential(
            nn.Linear(2048 + 1024, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes),
        )

    def forward(self, stream, detr_memory):
        """
        Parameters:
        -----------
        stream: Tensor
            (Q, input_num) instance tokens
        detr_memory: Tensor
            (1, input_num, H, W) memory of a single image

        Returns:
        --------
        feat: Tensor
            (Q, num_classes) instance scores
        score: Tensor
            (Q, hidden_num, H * W) attention weights, softmax-normalised over
            the spatial positions
        """
        height, width = detr_memory.shape[-2], detr_memory.shape[-1]
        direct = self.Qfc(stream)                       # (Q, 2048)
        query = self.Qconv(stream)                      # (Q, hidden)
        num_q = query.shape[0]
        keys = self.Kconv(detr_memory).flatten(-2)      # (1, hidden, H*W)
        values = self.Vconv(detr_memory).flatten(-2)
        if keys.dim() == 3:
            # Drop only the batch dim (expected to be 1). A bare squeeze() would
            # also drop the spatial axis when H * W == 1.
            keys, values = keys.squeeze(0), values.squeeze(0)
        # Per-channel attention: score[q, c, p] = query[q, c] * keys[c, p]
        score = query.unsqueeze(-1) * keys.unsqueeze(0)
        score = F.softmax(score, -1)
        feat = (score * values).reshape(num_q, -1, height, width)
        # Spatial average pooling -> (Q, hidden)
        feat = F.adaptive_avg_pool2d(feat, (1, 1)).reshape(-1, self.hidden_num)
        feat = self.fc(feat)
        feat = torch.cat([direct, feat], -1)
        feat = self.inst_fc(feat)
        return feat, score